{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import cv2\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchvision\n",
    "\n",
    "from boxmot import BotSort\n",
    "\n",
    "# Load a pre-trained Mask R-CNN model from torchvision\n",
    "device = torch.device('cpu')  # Change to 'cuda' if you have a GPU available\n",
    "segmentation_model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)\n",
    "segmentation_model.eval().to(device)\n",
    "\n",
    "tracker = BotSort(\n",
    "    reid_weights=Path('osnet_x0_25_msmt17.pt'),  # ReID model to use\n",
    "    device=device,\n",
    "    half=False,\n",
    ")\n",
    "\n",
    "# Open the video file\n",
    "vid = cv2.VideoCapture(0)\n",
    "\n",
    "# Function to generate a unique color for each track ID\n",
    "def get_color(track_id):\n",
    "    np.random.seed(int(track_id))\n",
    "    return tuple(np.random.randint(0, 255, 3).tolist())\n",
    "\n",
    "while True:\n",
    "    ret, im = vid.read()\n",
    "    if not ret:\n",
    "        break\n",
    "\n",
    "    # Convert frame to tensor and move to device\n",
    "    frame_tensor = torchvision.transforms.functional.to_tensor(im).unsqueeze(0).to(device)\n",
    "\n",
    "    # Run the Mask R-CNN model to detect bounding boxes and masks\n",
    "    with torch.no_grad():\n",
    "        results = segmentation_model(frame_tensor)[0]\n",
    "\n",
    "    # Extract detections (bounding boxes, masks, and scores)\n",
    "    dets = []\n",
    "    masks = []\n",
    "    confidence_threshold = 0.5\n",
    "\n",
    "    for i, score in enumerate(results['scores']):\n",
    "        if score >= confidence_threshold:\n",
    "            # Extract bounding box and score\n",
    "            x1, y1, x2, y2 = results['boxes'][i].cpu().numpy()\n",
    "            conf = score.item()\n",
    "            cls = results['labels'][i].item()  # Assuming 'labels' represents the class\n",
    "            dets.append([x1, y1, x2, y2, conf, cls])\n",
    "\n",
    "            # Extract mask and add to list\n",
    "            mask = results['masks'][i, 0].cpu().numpy()  # Use the first channel (binary mask)\n",
    "            masks.append(mask)\n",
    "\n",
    "    # Convert detections to a numpy array (N x (x, y, x, y, conf, cls))\n",
    "    dets = np.array(dets)\n",
    "\n",
    "    # Update tracker with detections and image\n",
    "    tracks = tracker.update(dets, im)  # M x (x, y, x, y, id, conf, cls, ind)\n",
    "\n",
    "    # Draw segmentation masks and bounding boxes in a single loop\n",
    "    if len(tracks) > 0:\n",
    "        inds = tracks[:, 7].astype('int')  # Get track indices as int\n",
    "\n",
    "        # Use the indices to match tracks with masks\n",
    "        if len(masks) > 0:\n",
    "            masks = [masks[i] for i in inds if i < len(masks)]  # Reorder masks to match the tracks\n",
    "\n",
    "        # Iterate over tracks and corresponding masks to draw them together\n",
    "        for track, mask in zip(tracks, masks):\n",
    "            track_id = int(track[4])  # Extract track ID\n",
    "            color = get_color(track_id)  # Use unique color for each track\n",
    "            \n",
    "            # Draw the segmentation mask on the image\n",
    "            if mask is not None:\n",
    "                # Binarize the mask\n",
    "                mask = (mask > 0.5).astype(np.uint8)\n",
    "                \n",
    "                # Blend mask color with the image\n",
    "                im[mask == 1] = im[mask == 1] * 0.5 + np.array(color) * 0.5\n",
    "\n",
    "            # Draw the bounding box\n",
    "            x1, y1, x2, y2 = track[:4].astype('int')\n",
    "            cv2.rectangle(im, (x1, y1), (x2, y2), color, 2)\n",
    "            \n",
    "            # Add text with ID, confidence, and class\n",
    "            conf = track[5]\n",
    "            cls = track[6]\n",
    "            cv2.putText(im, f'ID: {track_id}, Conf: {conf:.2f}, Class: {cls}', \n",
    "                        (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)\n",
    "\n",
    "    # Display the image\n",
    "    cv2.imshow('Segmentation Tracking', im)\n",
    "\n",
    "    # Break on pressing q or space\n",
    "    key = cv2.waitKey(1) & 0xFF\n",
    "    if key == ord(' ') or key == ord('q'):\n",
    "        break\n",
    "\n",
    "vid.release()\n",
    "cv2.destroyAllWindows()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "boxmot-YDNZdsaB-py3.11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
